"""
This module includes some utility functions.

The methods most typically used are the sigencode and sigdecode functions
to be used with :func:`~ecdsa.keys.SigningKey.sign` and
:func:`~ecdsa.keys.VerifyingKey.verify`
respectively. See the :func:`sigencode_strings`, :func:`sigencode_string`,
:func:`sigencode_der`, :func:`sigencode_strings_canonize`,
:func:`sigencode_string_canonize`, :func:`sigencode_der_canonize`,
:func:`sigdecode_strings`, :func:`sigdecode_string`, and
:func:`sigdecode_der` functions.
"""

from __future__ import division

import os
import math
import binascii
import sys
from hashlib import sha256
from six import PY2, int2byte, next
from . import der
from ._compat import normalise_bytes


# RFC5480:
#   The "unrestricted" algorithm identifier is:
#     id-ecPublicKey OBJECT IDENTIFIER ::= {
#       iso(1) member-body(2) us(840) ansi-X9-62(10045) keyType(2) 1 }

# Arcs of id-ecPublicKey, plus their DER encoding computed once at import
# time with der.encode_oid().
oid_ecPublicKey = (1, 2, 840, 10045, 2, 1)
encoded_oid_ecPublicKey = der.encode_oid(*oid_ecPublicKey)

# RFC5480:
# The ECDH algorithm uses the following object identifier:
#      id-ecDH OBJECT IDENTIFIER ::= {
#        iso(1) identified-organization(3) certicom(132) schemes(1)
#        ecdh(12) }

# Arcs of id-ecDH (kept as a tuple of ints, not pre-encoded).
oid_ecDH = (1, 3, 132, 1, 12)

# RFC5480:
# The ECMQV algorithm uses the following object identifier:
#      id-ecMQV OBJECT IDENTIFIER ::= {
#        iso(1) identified-organization(3) certicom(132) schemes(1)
#        ecmqv(13) }

# Arcs of id-ecMQV (kept as a tuple of ints, not pre-encoded).
oid_ecMQV = (1, 3, 132, 1, 13)

if sys.version_info >= (3,):  # pragma: no branch

    def entropy_to_bits(ent_256):
        """Convert a bytestring to string of 0's and 1's

        Every input byte contributes exactly eight characters, most
        significant bit first, so an empty input yields an empty string.

        :param bytes ent_256: raw bytes to convert
        :return: binary digits, ``8 * len(ent_256)`` characters long
        :rtype: str
        """
        if not ent_256:
            # bin(0) is "0b0" and zfill(0) would leave a spurious "0",
            # so zero bytes of input must be special-cased
            return ""
        return bin(int.from_bytes(ent_256, "big"))[2:].zfill(len(ent_256) * 8)

else:

    def entropy_to_bits(ent_256):
        """Convert a bytestring to string of 0's and 1's

        Every input byte contributes exactly eight characters, most
        significant bit first, so an empty input yields an empty string.

        :param str ent_256: raw bytes to convert
        :return: binary digits, ``8 * len(ent_256)`` characters long
        :rtype: str
        """
        return "".join(bin(ord(x))[2:].zfill(8) for x in ent_256)


if sys.version_info < (2, 7):  # pragma: no branch
    # int.bit_length() does not exist before Python 2.7 and a method cannot
    # be bolted onto a built-in type, so count the digits bin() produces.
    def bit_length(x):
        """Return the number of binary digits of ``x`` (``0`` counts as 1)."""
        return len(bin(x)[2:])

else:

    def bit_length(x):
        """Return ``x.bit_length()``, reporting ``0`` as one bit wide."""
        return max(x.bit_length(), 1)


def orderlen(order):
    """Return the number of bytes needed to hold ``order`` big-endian.

    The count is taken from the hexadecimal representation: two hex digits
    per byte, with a leftover odd digit needing one more byte.
    """
    whole, odd = divmod(len("%x" % order), 2)
    return whole + odd


def randrange(order, entropy=None):
    """Return a random integer k such that 1 <= k < order, uniformly
    distributed across that range. Worst case should be a mean of 2 loops at
    (2**k)+2.

    Note that this function is not declared to be forwards-compatible: we may
    change the behavior in future releases. The entropy= argument (which
    should get a callable that behaves like os.urandom) can be used to
    achieve stability within a given release (for repeatable unit tests), but
    should not be used as a long-term-compatible key generation algorithm.
    """
    assert order > 1
    read = os.urandom if entropy is None else entropy
    # k - 1 must cover [0, order - 2]; that many bits are drawn per attempt
    nbits = bit_length(order - 2)
    # always request one byte more than the whole-byte count so at least
    # ``nbits`` bits are available
    nbytes = nbits // 8 + 1
    while True:
        bitstring = entropy_to_bits(read(nbytes))
        candidate = 1 + int(bitstring[:nbits], 2)
        # rejection sampling keeps the accepted values uniform
        if 0 < candidate < order:
            return candidate


class PRNG:
    """Short-lived deterministic byte stream derived from a seed.

    Calling an instance with an integer ``N`` returns the next ``N`` bytes of
    a stream formed by concatenating the SHA-256 digests of the strings
    ``"prng-<counter>-<seed>"`` for ``counter = 0, 1, 2, ...``.

    It is meant primarily for :func:`randrange_from_seed__trytryagain`,
    which draws only a few bytes per seed. It does not protect against
    state compromise (no forward security).

    NOTE(review): the seed is interpolated with ``%s``, so on Python 3 a
    ``bytes`` seed contributes its ``repr`` (``b'...'``) to the hashed text.
    """

    def __init__(self, seed):
        self.generator = self.block_generator(seed)

    def __call__(self, numbytes):
        stream = self.generator
        chunk = [next(stream) for _ in range(numbytes)]
        # digest iteration yields 1-char strings on Python 2, ints on Python 3
        return "".join(chunk) if PY2 else bytes(chunk)

    def block_generator(self, seed):
        block_index = 0
        while True:
            label = ("prng-%d-%s" % (block_index, seed)).encode()
            digest = sha256(label).digest()
            for byte in digest:
                yield byte
            block_index += 1


def randrange_from_seed__overshoot_modulo(seed, order):
    """Derive an integer in ``[1, order)`` from ``seed`` by overshoot + modulo.

    Following David-Sarah Hopwood's suggestion, twice as many bytes as the
    order needs are drawn from :class:`PRNG`, read as a big-endian integer
    much larger than the group order, and reduced modulo ``order - 1``. This
    gives adequate (but not perfect) uniformity with simple code; the main
    alternative is the try-try-again approach.
    """
    stream = PRNG(seed)
    raw = stream(orderlen(order) * 2)
    overshoot = int(binascii.hexlify(raw), 16)
    number = 1 + overshoot % (order - 1)
    assert 1 <= number < order, (1, number, order)
    return number


def lsb_of_ones(numbits):
    """Return an integer whose ``numbits`` least significant bits are set.

    ``-1 << numbits`` has every bit set except the low ``numbits`` ones;
    inverting it leaves exactly those low bits set.
    """
    return ~(-1 << numbits)


def bits_and_bytes(order):
    """Split the bit length of ``order - 1`` into whole bytes and spare bits.

    The bit count is computed with a floating-point logarithm. For very
    large values just below a power of two it can come out one higher than
    the exact bit length. This is kept as-is because deterministic
    derivations built on it depend on the exact value.

    :param int order: upper bound of the range, must be > 1
    :return: ``(bits, whole_bytes, extra_bits)`` with
        ``bits == 8 * whole_bytes + extra_bits``
    :rtype: tuple(int, int, int)
    """
    bit_count = int(1 + math.log(order - 1, 2))
    whole, leftover = divmod(bit_count, 8)
    return bit_count, whole, leftover


# the following randrange_from_seed__METHOD() functions take an
# arbitrarily-sized secret seed and turn it into a number that obeys the same
# range limits as randrange() above. They are meant for deriving consistent
# signing keys from a secret rather than generating them randomly, for
# example a protocol in which three signing keys are derived from a master
# secret. You should use a uniformly-distributed unguessable seed with about
# curve.baselen bytes of entropy. To use one, do this:
#   seed = os.urandom(curve.baselen) # or other starting point
#   secexp = ecdsa.util.randrange_from_seed__trytryagain(seed, curve.order)
#   sk = SigningKey.from_secret_exponent(secexp, curve)


def randrange_from_seed__truncate_bytes(seed, order, hashmod=sha256):
    """Derive an integer in ``[1, order)`` from ``seed`` by hash truncation.

    ``seed`` is hashed with ``hashmod``. The first ``ceil(bits / 8)`` bytes
    of the digest (``bits`` being the bit length of ``order - 1``, estimated
    as in :func:`bits_and_bytes`) are read as a big-endian integer ``x``.
    The digest is zero-padded on the left if it is shorter than that. The
    result is ``1 + x % (order - 1)``. No attempt is made to fill the range
    uniformly, so the output is biased.

    :param seed: secret seed, passed unchanged to ``hashmod``
    :param int order: exclusive upper bound of the result, must be > 1
    :param hashmod: hash constructor called as ``hashmod(seed).digest()``
    :return: integer ``k`` with ``1 <= k < order``
    :rtype: int
    """
    bits = int(math.log(order - 1, 2) + 1)
    nbytes = (bits + 7) // 8
    digest = hashmod(seed).digest()[:nbytes]
    # Left-pad a short digest with zero bytes. A bytearray is used because
    # the previous "\x00" + bytes concatenation raised TypeError on Python 3.
    buf = bytearray(nbytes - len(digest)) + bytearray(digest)
    # Fold values that would reach ``order`` back into range instead of
    # failing the assertion below. Values already below ``order - 1`` are
    # unchanged by the modulo, so previously valid results are preserved.
    number = 1 + (int(binascii.hexlify(buf), 16) % (order - 1))
    assert 1 <= number < order
    return number


def randrange_from_seed__truncate_bits(seed, order, hashmod=sha256):
    """Derive an integer in ``[1, order)`` from ``seed`` by bit truncation.

    Like :func:`randrange_from_seed__truncate_bytes`, but the digest is cut
    down to the exact bit length of ``order - 1`` rather than to whole
    bytes, so only about half a bit of entropy is lost on average. The
    truncated value ``x`` is mapped to ``1 + x % (order - 1)``.

    :param seed: secret seed, passed unchanged to ``hashmod``
    :param int order: exclusive upper bound of the result, must be > 1
    :param hashmod: hash constructor called as ``hashmod(seed).digest()``
    :return: integer ``k`` with ``1 <= k < order``
    :rtype: int
    """
    bits = int(math.log(order - 1, 2) + 1)
    maxbytes = (bits + 7) // 8
    digest = hashmod(seed).digest()[:maxbytes]
    # Left-pad a short digest with zero bytes. A bytearray is used because
    # mixing str and bytes, or calling ord() on a bytes item, fails on
    # Python 3. Indexing a bytearray yields ints on both Python 2 and 3.
    buf = bytearray(maxbytes - len(digest)) + bytearray(digest)
    topbits = 8 * maxbytes - bits
    if topbits:
        # Clear the ``topbits`` most significant bits of the first byte,
        # i.e. keep its low ``8 - topbits`` bits, so the value has at most
        # ``bits`` bits. The previous mask kept ``topbits`` bits instead.
        buf[0] &= (1 << (8 - topbits)) - 1
    # Fold the (rare) values that would reach ``order`` back into range;
    # values already below ``order - 1`` are unchanged by the modulo.
    number = 1 + (int(binascii.hexlify(buf), 16) % (order - 1))
    assert 1 <= number < order
    return number


def randrange_from_seed__trytryagain(seed, order):
    """Derive an integer in ``[1, order)`` from ``seed`` by rejection sampling.

    The number of bits needed is taken from :func:`bits_and_bytes`, which
    keeps the chance of rejecting a candidate below 0.5. Candidates are fed
    from a byte-oriented :class:`PRNG`. When the bit count is not a
    multiple of 8, one extra byte is drawn first and masked down to the
    leftover low bits. The average number of loops ranges from 1.0 (when
    order=2**k-1) to 2.0 (when order=2**k+1).
    """
    assert order > 1
    _, whole_bytes, extrabits = bits_and_bytes(order)
    generate = PRNG(seed)
    top_mask = lsb_of_ones(extrabits)
    while True:
        # the masked partial byte is drawn before the whole bytes
        if extrabits:
            head = int2byte(ord(generate(1)) & top_mask)
        else:
            head = b""
        guess = 1 + string_to_number(head + generate(whole_bytes))
        if 1 <= guess < order:
            return guess


def number_to_string(num, order):
    """Encode ``num`` big-endian in exactly ``orderlen(order)`` bytes.

    Raises AssertionError if ``num`` does not fit in that many bytes.
    """
    width = orderlen(order)
    hex_text = "%0*x" % (2 * width, num)
    encoded = binascii.unhexlify(hex_text.encode())
    assert len(encoded) == width, (len(encoded), width)
    return encoded


def number_to_string_crop(num, order):
    """Encode ``num`` big-endian, keeping at most ``orderlen(order)`` bytes.

    When ``num`` needs more bytes than that, only the leading (most
    significant) bytes are kept.
    """
    width = orderlen(order)
    hex_text = "%0*x" % (2 * width, num)
    encoded = binascii.unhexlify(hex_text.encode())
    return encoded[:width]


def string_to_number(string):
    """Interpret the byte string ``string`` as a big-endian unsigned integer."""
    hex_digits = binascii.hexlify(string)
    return int(hex_digits, 16)


def string_to_number_fixedlen(string, order):
    """Decode a big-endian byte string that must be ``orderlen(order)`` long.

    Raises AssertionError when the length does not match.
    """
    expected = orderlen(order)
    assert len(string) == expected, (len(string), expected)
    return string_to_number(string)


def sigencode_strings(r, s, order):
    """
    Encode the signature to a pair of strings in a tuple

    Produces the :term:`raw encoding` of the signature, keeping the ``r``
    and ``s`` parts as two separate fixed-width byte strings.

    Intended to be passed as the ``sigencode=`` argument of the
    :func:`ecdsa.keys.SigningKey.sign` method.

    :param int r: first parameter of the signature
    :param int s: second parameter of the signature
    :param int order: the order of the curve over which the signature was
        computed

    :return: raw encoding of ECDSA signature
    :rtype: tuple(bytes, bytes)
    """
    return tuple(number_to_string(part, order) for part in (r, s))


def sigencode_string(r, s, order):
    """
    Encode the signature to raw format (:term:`raw encoding`)

    Intended to be passed as the ``sigencode=`` argument of the
    :func:`ecdsa.keys.SigningKey.sign` method.

    :param int r: first parameter of the signature
    :param int s: second parameter of the signature
    :param int order: the order of the curve over which the signature was
        computed

    :return: raw encoding of ECDSA signature
    :rtype: bytes
    """
    # Both halves have a fixed width for a given curve, so joining them
    # back to back is unambiguous.
    return b"".join(sigencode_strings(r, s, order))


def sigencode_der(r, s, order):
    """
    Encode the signature into the ECDSA-Sig-Value structure using :term:`DER`.

    The output follows this :term:`ASN.1` structure::

        Ecdsa-Sig-Value ::= SEQUENCE {
            r       INTEGER,
            s       INTEGER
        }

    Intended to be passed as the ``sigencode=`` argument of the
    :func:`ecdsa.keys.SigningKey.sign` method.

    :param int r: first parameter of the signature
    :param int s: second parameter of the signature
    :param int order: the order of the curve over which the signature was
        computed (not used by this encoding)

    :return: DER encoding of ECDSA signature
    :rtype: bytes
    """
    encoded_r = der.encode_integer(r)
    encoded_s = der.encode_integer(s)
    return der.encode_sequence(encoded_r, encoded_s)


def _canonize(s, order):
    """
    Internal function for ensuring that the ``s`` value of a signature is in
    the "canonical" format.

    :param int s: the second parameter of ECDSA signature
    :param int order: the order of the curve over which the signatures was
        computed

    :return: canonical value of s
    :rtype: int
    """
    if s > order // 2:
        s = order - s
    return s


def sigencode_strings_canonize(r, s, order):
    """
    Encode the signature to a pair of strings in a tuple

    Produces the :term:`raw encoding` of the signature, keeping the ``r``
    and ``s`` parts as two separate byte strings.

    Before encoding, ``s`` is put in canonical ("low-S") form: values above
    ``order // 2`` are replaced by ``order - s``, as most commonly required
    in bitcoin.

    Intended to be passed as the ``sigencode=`` argument of the
    :func:`ecdsa.keys.SigningKey.sign` method.

    :param int r: first parameter of the signature
    :param int s: second parameter of the signature
    :param int order: the order of the curve over which the signature was
        computed

    :return: raw encoding of ECDSA signature
    :rtype: tuple(bytes, bytes)
    """
    return sigencode_strings(r, _canonize(s, order), order)


def sigencode_string_canonize(r, s, order):
    """
    Encode the signature to raw format (:term:`raw encoding`)

    Before encoding, ``s`` is put in canonical ("low-S") form: values above
    ``order // 2`` are replaced by ``order - s``, as most commonly required
    in bitcoin.

    Intended to be passed as the ``sigencode=`` argument of the
    :func:`ecdsa.keys.SigningKey.sign` method.

    :param int r: first parameter of the signature
    :param int s: second parameter of the signature
    :param int order: the order of the curve over which the signature was
        computed

    :return: raw encoding of ECDSA signature
    :rtype: bytes
    """
    return sigencode_string(r, _canonize(s, order), order)


def sigencode_der_canonize(r, s, order):
    """
    Encode the signature into the ECDSA-Sig-Value structure using :term:`DER`.

    Before encoding, ``s`` is put in canonical ("low-S") form: values above
    ``order // 2`` are replaced by ``order - s``, as most commonly required
    in bitcoin.

    The output follows this :term:`ASN.1` structure::

        Ecdsa-Sig-Value ::= SEQUENCE {
            r       INTEGER,
            s       INTEGER
        }

    Intended to be passed as the ``sigencode=`` argument of the
    :func:`ecdsa.keys.SigningKey.sign` method.

    :param int r: first parameter of the signature
    :param int s: second parameter of the signature
    :param int order: the order of the curve over which the signature was
        computed

    :return: DER encoding of ECDSA signature
    :rtype: bytes
    """
    return sigencode_der(r, _canonize(s, order), order)


class MalformedSignature(Exception):
    """
    Signature could not be decoded because it is malformed.

    Raised by the decoding functions when the supplied strings or integers
    cannot have been produced by a signature over the given curve, for
    example because the byte strings have the wrong length or the encoded
    values are too large.
    """


def sigdecode_string(signature, order):
    """
    Decoder for :term:`raw encoding` of ECDSA signatures.

    The raw encoding is the two signature integers concatenated, each
    written big-endian in the same number of bytes, determined by the curve
    order.

    Intended to be passed as the ``sigdecode=`` argument of the
    :func:`ecdsa.keys.VerifyingKey.verify` method.

    :param signature: encoded signature
    :type signature: bytes like object
    :param order: order of the curve over which the signature was computed
    :type order: int

    :raises MalformedSignature: when the encoding of the signature is invalid

    :return: tuple with decoded ``r`` and ``s`` values of signature
    :rtype: tuple of ints
    """
    signature = normalise_bytes(signature)
    half = orderlen(order)
    if len(signature) != 2 * half:
        raise MalformedSignature(
            "Invalid length of signature, expected {0} bytes long, "
            "provided string is {1} bytes long".format(
                2 * half, len(signature)
            )
        )
    r_bytes, s_bytes = signature[:half], signature[half:]
    return (
        string_to_number_fixedlen(r_bytes, order),
        string_to_number_fixedlen(s_bytes, order),
    )


def sigdecode_strings(rs_strings, order):
    """
    Decode the signature from two strings.

    The first string must be the big-endian encoding of ``r`` and the second
    the big-endian encoding of ``s``, each exactly as long as the curve order
    requires.

    Intended to be passed as the ``sigdecode=`` argument of the
    :func:`ecdsa.keys.VerifyingKey.verify` method.

    :param list rs_strings: list of two bytes-like objects, each encoding one
        parameter of signature
    :param int order: order of the curve over which the signature was computed

    :raises MalformedSignature: when the encoding of the signature is invalid

    :return: tuple with decoded ``r`` and ``s`` values of signature
    :rtype: tuple of ints
    """
    count = len(rs_strings)
    if count != 2:
        raise MalformedSignature(
            "Invalid number of strings provided: {0}, expected 2".format(count)
        )
    r_str, s_str = rs_strings
    r_str = normalise_bytes(r_str)
    s_str = normalise_bytes(s_str)
    width = orderlen(order)

    if len(r_str) != width:
        raise MalformedSignature(
            "Invalid length of first string ('r' parameter), "
            "expected {0} bytes long, provided string is {1} "
            "bytes long".format(width, len(r_str))
        )
    if len(s_str) != width:
        raise MalformedSignature(
            "Invalid length of second string ('s' parameter), "
            "expected {0} bytes long, provided string is {1} "
            "bytes long".format(width, len(s_str))
        )

    return (
        string_to_number_fixedlen(r_str, order),
        string_to_number_fixedlen(s_str, order),
    )


def sigdecode_der(sig_der, order):
    """
    Decoder for DER format of ECDSA signatures.

    The DER format encodes the signature with the :term:`ASN.1` :term:`DER`
    rules as a sequence of two integers::

        Ecdsa-Sig-Value ::= SEQUENCE {
            r       INTEGER,
            s       INTEGER
        }

    Intended to be passed as the ``sigdecode=`` argument of the
    :func:`ecdsa.keys.VerifyingKey.verify` method.

    :param sig_der: encoded signature
    :type sig_der: bytes like object
    :param order: order of the curve over which the signature was computed
        (not used by this decoder; presumably accepted so the signature
        matches the other ``sigdecode`` functions)
    :type order: int

    :raises UnexpectedDER: when the encoding of signature is invalid

    :return: tuple with decoded ``r`` and ``s`` values of signature
    :rtype: tuple of ints
    """
    sig_der = normalise_bytes(sig_der)

    # outer SEQUENCE: nothing may follow it
    body, trailer = der.remove_sequence(sig_der)
    if trailer != b"":
        raise der.UnexpectedDER(
            "trailing junk after DER sig: %s" % binascii.hexlify(trailer)
        )

    # inside the SEQUENCE: exactly two INTEGERs
    r, remainder = der.remove_integer(body)
    s, leftover = der.remove_integer(remainder)
    if leftover != b"":
        raise der.UnexpectedDER(
            "trailing junk after DER numbers: %s" % binascii.hexlify(leftover)
        )
    return r, s
